/*
 * Copyright (c) 2023-2023 elsfs Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.elsfs.cloud.common.mybatis.interceptor;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.cache.impl.PerpetualCache;
import org.apache.ibatis.executor.CachingExecutor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.elsfs.cloud.common.annotations.CipherField;
import org.elsfs.cloud.common.annotations.EnableCipher;
import org.elsfs.cloud.common.annotations.ICipherAlgorithm;
import org.springframework.util.ReflectionUtils;

/**
 * 基于mybatis拦截器实现查询字段脱敏，敏感字段加解密
 *
 * @author zeng
 */
@Intercepts({
  @Signature(
      type = Executor.class,
      method = "update",
      args = {MappedStatement.class, Object.class}),
  @Signature(
      type = Executor.class,
      method = "query",
      args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
  @Signature(
      type = Executor.class,
      method = "query",
      args = {
        MappedStatement.class,
        Object.class,
        RowBounds.class,
        ResultHandler.class,
        CacheKey.class,
        BoundSql.class
      }),
})
@Slf4j
public class MybatisSensitiveInterceptor implements Interceptor {

  private final ThreadLocal<Map<Object, Map<Field, CipherValue>>> cipherFieldMapLocal =
      ThreadLocal.withInitial(ConcurrentHashMap::new);
  private static final String SELECT_KEY = "!selectKey";
  private static final int EXECUTOR_PARAMETER_COUNT_4 = 4;
  private static final int MAPPED_STATEMENT_INDEX = 0;
  private static final int PARAMETER_INDEX = 1;
  private static final int ROW_BOUNDS_INDEX = 2;
  private static final int CACHE_KEY_INDEX = 4;

  private final Map<String, ICipherAlgorithm> algorithmMap = new ConcurrentHashMap<>(2);

  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    Object[] args = invocation.getArgs();
    MappedStatement statement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
    Object parameter = args[PARAMETER_INDEX];
    // 处理selectKey的特殊情况
    handleSelectKey(statement, parameter);
    EnableCipher enableCipher = getEnableCipher(statement);
    if (enableCipher == null) {
      return invocation.proceed();
    }
    // 获取所有的注解参数字段，根据每个字段对应的算法进行加密或解密
    handleParameter(parameter, enableCipher.parameter());
    // 判断是否命中缓存
    // TODO 缓存处理
    //  boolean hitCache = hitCache(invocation, parameter);
    // 执行proceed
    Object proceed;
    try {
      proceed = invocation.proceed();
    } catch (Exception e) {
      throw new RuntimeException(e);
    } finally {
      // 还原参数
      revertParameter(parameter);
    }
    // 执行结果处理
    // TODO 缓存处理
    // return hitCache ? proceed : handleResult(proceed, enableCipher.result());
    return handleResult(proceed, enableCipher.result());
  }

  @Override
  public Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  @Override
  public void setProperties(Properties properties) {}

  /**
   * 获取mapper上的EnableCipher注解
   *
   * @param statement MappedStatement对象
   * @return 返回EnableCipher注解对象
   */
  private EnableCipher getEnableCipher(MappedStatement statement) {
    String namespace = statement.getId();
    String className = namespace.substring(0, namespace.lastIndexOf("."));
    String methodName = statement.getId().substring(statement.getId().lastIndexOf(".") + 1);
    Method[] methods;
    try {
      methods = Class.forName(className).getMethods();
      for (Method method : methods) {
        if (method.getName().equals(methodName)) {
          if (method.isAnnotationPresent(EnableCipher.class)) {
            return method.getAnnotation(EnableCipher.class);
          }
        }
      }
    } catch (ClassNotFoundException e) {
      LOGGER.error("get @EnableCipher from {} error!", namespace);
    }
    return null;
  }

  /**
   * 处理selectKey
   *
   * @param statement statement
   * @param parameter parameter
   */
  private void handleSelectKey(MappedStatement statement, Object parameter) {
    SqlCommandType commandType = statement.getSqlCommandType();
    if (commandType == SqlCommandType.SELECT && statement.getId().endsWith(SELECT_KEY)) {
      revertParameter(parameter);
    }
  }

  /**
   * 处理参数 是否加解密
   *
   * @param parameter 输入参数
   * @param cipherType 加解密方式
   */
  private void handleParameter(Object parameter, EnableCipher.CipherType cipherType) {
    if (cipherType == EnableCipher.CipherType.NONE) {
      return;
    }
    Map<Object, Map<Field, CipherValue>> cipherMap = new HashMap<>();
    if (parameter instanceof Map<?, ?> parameterMap) {
      Map<Object, Object> map = filterRepeatValueMap(parameterMap);
      map.forEach(
          (k, v) -> {
            Map<Object, Map<Field, CipherValue>> valueMap;
            if (v instanceof Collection<?> v1) {
              valueMap = handleCipher(v1, cipherType);
            } else {
              valueMap = handleCipher(Collections.singleton(v), cipherType);
            }
            if (valueMap != null && !valueMap.isEmpty()) {
              cipherMap.putAll(valueMap);
            }
          });
    } else {
      Map<Object, Map<Field, CipherValue>> valueMap =
          handleCipher(Collections.singleton(parameter), cipherType);
      if (valueMap != null && !valueMap.isEmpty()) {
        cipherMap.putAll(valueMap);
      }
    }
    // ThreadLocal临时保存处理过的字段值
    if (!cipherMap.isEmpty()) {
      cipherFieldMapLocal.get().putAll(cipherMap);
      // 可以打印参数加密前和加密后的值 cipherMap
    }
  }

  /**
   * 处理结果
   *
   * @param result 执行结果对象
   * @param cipherType 加解密方式
   * @return 返回处理后的结果对象
   */
  private Object handleResult(Object result, EnableCipher.CipherType cipherType) {
    if (cipherType == EnableCipher.CipherType.NONE) {
      return result;
    }
    if (result instanceof Collection<?> r) {
      handleCipher(r, cipherType);
    } else {
      handleCipher(Collections.singleton(result), cipherType);
    }
    return result;
  }

  /**
   * 还原参数
   *
   * @param parameter p
   */
  private void revertParameter(Object parameter) {
    final Map<Object, Map<Field, CipherValue>> cipherFieldMap = cipherFieldMapLocal.get();
    if (cipherFieldMap.isEmpty()) {
      return;
    }
    if (parameter instanceof Map<?, ?> map) {
      Map<Object, Object> parameterMap = filterRepeatValueMap(map);
      parameterMap.forEach(
          (k, v) -> {
            if (v instanceof Collection<?> v1) {
              v1.stream()
                  .filter(Objects::nonNull)
                  .forEach(
                      obj -> {
                        Map<Field, CipherValue> valueMap = cipherFieldMap.get(obj);
                        if (Objects.nonNull(valueMap)) {
                          valueMap.forEach(
                              (field, cipher) ->
                                  ReflectionUtils.setField(field, obj, cipher.getBefore()));
                        }
                      });
            } else {
              Map<Field, CipherValue> valueMap = cipherFieldMap.get(v);
              if (Objects.nonNull(valueMap)) {
                valueMap.forEach(
                    (field, cipher) -> ReflectionUtils.setField(field, v, cipher.getBefore()));
              }
            }
          });
    } else {
      Map<Field, CipherValue> valueMap = cipherFieldMap.get(parameter);
      if (Objects.nonNull(valueMap)) {
        valueMap.forEach(
            (field, cipher) -> ReflectionUtils.setField(field, parameter, cipher.getBefore()));
      }
    }
    cipherFieldMap.clear();
  }

  /**
   * 查询语句是否命中缓存
   *
   * @param invocation 拦截器方法对象
   * @param parameter 处理过后的参数对象
   * @return 是否命中缓存
   */
  private boolean hitCache(Invocation invocation, Object parameter) throws IllegalAccessException {
    Object[] args = invocation.getArgs();
    MappedStatement mappedStatement = (MappedStatement) args[MAPPED_STATEMENT_INDEX];
    Executor executor = (Executor) invocation.getTarget();
    SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
    // 非查询语句直接返回false
    if (!SqlCommandType.SELECT.equals(sqlCommandType)) {
      return false;
    }
    RowBounds rowBounds = (RowBounds) args[ROW_BOUNDS_INDEX];
    BoundSql boundSql;
    CacheKey cacheKey;
    if (args.length == EXECUTOR_PARAMETER_COUNT_4) {
      boundSql = mappedStatement.getBoundSql(parameter);
      cacheKey = executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
    } else {
      cacheKey = (CacheKey) args[CACHE_KEY_INDEX];
    }
    Executor baseExecutor;
    if (executor instanceof CachingExecutor) {
      Field field = ReflectionUtils.findField(CachingExecutor.class, "delegate");
      assert field != null;
      field.setAccessible(true);
      baseExecutor = (Executor) field.get(executor);
    } else {
      baseExecutor = (Executor) invocation.getTarget();
    }
    Field field = ReflectionUtils.findField(CachingExecutor.class, "localCache");
    assert field != null;
    field.setAccessible(true);
    PerpetualCache localCache = (PerpetualCache) field.get(baseExecutor);
    return Objects.nonNull(localCache.getObject(cacheKey));
  }

  /**
   * 获取算法对象
   *
   * @param value 加解密算法子类
   * @return 返回算法对象
   */
  private ICipherAlgorithm getCipherAlgorithm(Class<? extends ICipherAlgorithm> value) {
    String canonicalName = value.getCanonicalName();
    if (algorithmMap.containsKey(canonicalName)) {
      return algorithmMap.get(canonicalName);
    }
    try {
      ICipherAlgorithm algorithm = value.getDeclaredConstructor().newInstance();
      algorithmMap.put(value.getName(), algorithm);
      return algorithm;
    } catch (Exception e) {
      throw new RuntimeException("init ICipherAlgorithm error", e);
    }
  }

  /**
   * 加解密操作对象
   *
   * @param collection 输入参数
   * @param cipherType 加解密方式
   * @return 返回已处理字段的处理前后值
   */
  private Map<Object, Map<Field, CipherValue>> handleCipher(
      Collection<?> collection, EnableCipher.CipherType cipherType) {
    if (collection == null || collection.isEmpty()) {
      return null;
    }
    // 遍历参数，处理加解密
    Map<Object, Map<Field, CipherValue>> result = new HashMap<>();
    collection.forEach(
        object -> {
          Map<Field, CipherValue> valueMap = new HashMap<>();
          this.getFields(object).stream()
              .filter(
                  field ->
                      field.isAnnotationPresent(CipherField.class)
                          && field.getType() == String.class)
              .forEach(
                  field -> {
                    CipherField cipherField = field.getAnnotation(CipherField.class);
                    ICipherAlgorithm algorithm = getCipherAlgorithm(cipherField.value());
                    String value = (String) getField(field, object);
                    if (Objects.nonNull(value)) {
                      String algorithmValue = null;
                      if (cipherType == EnableCipher.CipherType.ENCRYPT) {
                        algorithmValue = algorithm.encrypt(value);
                      }
                      if (cipherType == EnableCipher.CipherType.DECRYPT) {
                        algorithmValue = algorithm.decrypt(value);
                      }
                      if (Objects.nonNull(algorithmValue)) {
                        ReflectionUtils.setField(field, object, algorithmValue);
                        valueMap.put(field, new CipherValue(value, algorithmValue));
                      }
                    }
                  });
          if (!valueMap.isEmpty()) {
            result.put(object, valueMap);
          }
        });
    return result;
  }

  private Map<Object, Object> filterRepeatValueMap(Map<?, ?> parameter) {
    Set<Integer> hashCodeSet = new HashSet<>();
    return (parameter)
        .entrySet().stream()
            .filter(e -> Objects.nonNull(e.getValue()))
            .filter(
                r -> {
                  if (!hashCodeSet.contains(r.getValue().hashCode())) {
                    hashCodeSet.add(r.getValue().hashCode());
                    return true;
                  }
                  return false;
                })
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  }

  private List<Field> getFields(Object obj) {
    List<Field> fieldList = new ArrayList<>();
    Class<?> tempClass = obj.getClass();
    while (tempClass != null) {
      fieldList.addAll(Arrays.asList(tempClass.getDeclaredFields()));
      tempClass = tempClass.getSuperclass();
    }
    return fieldList;
  }

  private Object getField(Field field, Object obj) {
    ReflectionUtils.makeAccessible(field);
    return ReflectionUtils.getField(field, obj);
  }

  @Getter
  static class CipherValue {
    String before;
    String after;

    public CipherValue(String before, String after) {
      // 处理前
      this.before = before;
      // 处理后
      this.after = after;
    }
  }
}
